
import torch
import numpy as np
import device
#device = "cuda" if torch.cuda.is_available() else "cpu"
import sys

import matplotlib.pyplot as plt

from channel import *
from neuralEQRNN import *



class simNeuralEQRNN():
	"""Train and evaluate a recurrent neural equalizer on symbol streams.

	The equalizer receives batches of received (channel output) samples and is
	trained to reproduce the transmitted symbols. Bit errors are counted by
	comparing the signs of prediction and target.

	The RNN hidden state returned by one batch is fed into the next batch
	(detached, so gradients do not flow across batches). The first batch of
	every pass only warms up that hidden state: it contributes neither to the
	loss nor to the bit-error count.
	"""

	def __init__(self, txDataTrain, rxDataTrain, txDataTest, rxDataTest, neuralEQ):
		"""Store the data streams and the equalizer model.

		Args:
			txDataTrain, rxDataTrain: transmitted / received training streams.
			txDataTest, rxDataTest: default transmitted / received test streams.
			neuralEQ: model called as ``pred, h = neuralEQ(x, hPrev)``.
		"""
		self.txDataTrain = txDataTrain
		self.rxDataTrain = rxDataTrain
		self.txDataTest = txDataTest
		self.rxDataTest = rxDataTest
		self.neuralEQ = neuralEQ
		self.dataSize = len(txDataTest)  # length of the default test stream

	def curatingData(self, rxData, txData, seqLength, batchSize):
		"""Cut paired rx/tx streams into batched RNN inputs.

		Both streams are truncated to their common length and split into
		consecutive, non-overlapping sequences of ``seqLength`` samples; every
		``batchSize`` consecutive sequences form one batch. Trailing samples
		that do not fill a complete batch are dropped.

		Args:
			rxData: received samples (1-D sequence, array or list).
			txData: transmitted symbols aligned with ``rxData``.
			seqLength: time steps per RNN input sequence.
			batchSize: sequences per batch.

		Returns:
			(rxDataSet, txDataSet): float32 tensors of shape
			(numBatches, batchSize, seqLength, 1), where element [b, j, i, 0]
			is sample (b*batchSize + j)*seqLength + i of the stream.

		Raises:
			ValueError: if ``seqLength`` or ``batchSize`` is not positive.
		"""
		if seqLength <= 0 or batchSize <= 0:
			raise ValueError(f"seqLength and batchSize must be positive, got {seqLength}, {batchSize}")
		rx = np.asarray(rxData, dtype=np.float32)
		tx = np.asarray(txData, dtype=np.float32)
		numBatches = min(len(rx), len(tx)) // (seqLength * batchSize)
		used = numBatches * batchSize * seqLength
		shape = (numBatches, batchSize, seqLength, 1)
		# C-order reshape: consecutive samples fill a sequence, consecutive
		# sequences fill a batch. torch.tensor copies, so callers' arrays are
		# never aliased.
		return torch.tensor(rx[:used].reshape(shape)), torch.tensor(tx[:used].reshape(shape))

	@staticmethod
	def _bitErrors(pred, y):
		"""Return sum(|sign(pred) - sign(y)|) / 2 as a Python float.

		A sign flip counts as 1 error; a zero on exactly one side counts 0.5.
		"""
		return (torch.sign(pred) - torch.sign(y)).abs().sum().item() / 2

	@staticmethod
	def _requireTwoBatches(numBatches, seqLength, batchSize):
		"""Raise ValueError unless there is at least one batch after the warm-up batch."""
		if numBatches < 2:
			raise ValueError(
				f"need at least 2 batches of {batchSize}x{seqLength} samples "
				f"(the first batch only warms up the hidden state), got {numBatches}")

	def trainNeuralEQ(self, lossFn, opt, seqLength=8,  batchSize=64, plot=False):
		"""Run one training pass (epoch) over the training streams.

		Args:
			lossFn: loss called as ``lossFn(pred, y)``.
			opt: optimizer over the equalizer's parameters.
			seqLength: time steps per RNN input sequence.
			batchSize: sequences per batch.
			plot: if True, print x / y / pred of every batch that produced bit
				errors (a debug aid; nothing is plotted).

		Returns:
			(loss, ber): detached loss tensor of the last batch, and the bit
			error rate over all scored (non-warm-up) symbols.

		Raises:
			ValueError: if the training streams yield fewer than two batches.
		"""
		self.neuralEQ.train()
		rxDataSet, txDataSet = self.curatingData(self.rxDataTrain, self.txDataTrain, seqLength, batchSize)
		numBatches = len(rxDataSet)
		self._requireTwoBatches(numBatches, seqLength, batchSize)

		bitErrors = 0.0
		hPrev = None
		for batchIdx in range(numBatches):
			x = rxDataSet[batchIdx].to(device.device)
			y = txDataSet[batchIdx].to(device.device)

			pred, h = self.neuralEQ(x, hPrev)
			# Carry the hidden state forward without backpropagating into earlier batches.
			hPrev = h.detach()
			if batchIdx == 0:
				continue  # warm-up batch: only establishes the hidden state

			loss = lossFn(pred, y)
			opt.zero_grad()
			loss.backward()
			opt.step()

			batchErrors = self._bitErrors(pred.detach(), y)
			bitErrors += batchErrors
			if plot and batchErrors != 0:
				print(f"batchIdx: {batchIdx}")
				print(f"x: {x}\ny: {y}\npred: {pred}\n beOld: {bitErrors - batchErrors}, be: {bitErrors}")

		# Normalize by the symbols actually scored (warm-up batch excluded).
		ber = bitErrors / ((numBatches - 1) * batchSize * seqLength)
		return loss.detach(), ber

	def evalNeuralEQ(self, lossFn, seqLength=8,  batchSize=64, rxDataTestNew=None, txDataTestNew=None ):
		"""Evaluate the equalizer without updating it.

		Args:
			lossFn: loss called as ``lossFn(pred, y)``.
			seqLength: time steps per RNN input sequence.
			batchSize: sequences per batch.
			rxDataTestNew, txDataTestNew: optional streams to evaluate instead of
				the stored test streams; used when ``rxDataTestNew`` is given.

		Returns:
			(testLoss, ber): mean loss per scored batch and the bit error rate
			over all scored symbols (the warm-up batch is excluded from both).

		Raises:
			ValueError: if ``rxDataTestNew`` is given without ``txDataTestNew``,
				or the data yields fewer than two batches.
		"""
		self.neuralEQ.eval()

		if rxDataTestNew is not None:
			if txDataTestNew is None:
				raise ValueError("txDataTestNew must be given together with rxDataTestNew")
			rxData, txData = rxDataTestNew, txDataTestNew
		else:
			rxData, txData = self.rxDataTest, self.txDataTest
		rxDataSet, txDataSet = self.curatingData(rxData, txData, seqLength, batchSize)
		numBatches = len(rxDataSet)
		self._requireTwoBatches(numBatches, seqLength, batchSize)

		testLoss, bitErrors = 0.0, 0.0
		hPrev = None
		with torch.no_grad():
			for batchIdx in range(numBatches):
				x = rxDataSet[batchIdx].to(device.device)
				y = txDataSet[batchIdx].to(device.device)

				pred, h = self.neuralEQ(x, hPrev)
				hPrev = h.detach()
				if batchIdx == 0:
					continue  # warm-up batch: only establishes the hidden state

				testLoss += lossFn(pred, y).item()
				bitErrors += self._bitErrors(pred, y)

		numScored = numBatches - 1
		testLoss /= numScored
		ber = bitErrors / (numScored * batchSize * seqLength)
		return testLoss, ber


if __name__ == '__main__':
	dataSizeTrain = int(1e4)
	dataSizeTest = int(1e5)
	dataSizeTestFinal = int(1e6)
	chSBR = [1.0, 0.5, 0.4, 0.2]
	batchSize = 100
	seqLength = 10
	snrTrain = 10
	snrTest = 10
	snrTestFinal = snrTest
	flagN = 0
	numEpoch = 1000
	hiddenSize = 10
	numLayers = 2
	bidir = False

	def genSequence(dataSize, snr):
		"""Return (chIn, chOut): random +/-1 symbols and a fresh Channel's output for them."""
		# np.int was removed in NumPy 1.24; randint already yields an integer array.
		chIn = 2 * np.random.randint(2, size=dataSize) - 1
		ch = Channel(sbr=chSBR, snr=snr)
		return chIn, ch.run(chIn=chIn, flagN=flagN)

	# Training data, test data monitored during training, and a larger
	# held-out set used only for the final evaluation.
	chInTrain, chOutTrain = genSequence(dataSizeTrain, snrTrain)
	chInTest, chOutTest = genSequence(dataSizeTest, snrTest)
	chInTestFinal, chOutTestFinal = genSequence(dataSizeTestFinal, snrTestFinal)

	nEQRNN = neuralEQRNN(hiddenSize=hiddenSize, numLayers=numLayers, bidir=bidir)
	nEQRNN = nEQRNN.to(device.device)
	lossFn = torch.nn.MSELoss()
	opt = torch.optim.Adam(nEQRNN.parameters(), lr=1e-2)

	simNEQRNN = simNeuralEQRNN(txDataTrain=chInTrain, rxDataTrain=chOutTrain, txDataTest=chInTest, rxDataTest=chOutTest, neuralEQ=nEQRNN)

	for k in range(numEpoch):
		loss, berTrain = simNEQRNN.trainNeuralEQ(lossFn, opt, batchSize=batchSize, seqLength=seqLength)
		if k % 10 == 0:
			print(f"trainloss: {loss}, trainber: {berTrain},  epoch:{k}/{numEpoch}", flush=True)
	# One extra training pass with per-batch error dumps enabled.
	loss, berTrain = simNEQRNN.trainNeuralEQ(lossFn, opt, batchSize=batchSize, seqLength=seqLength, plot=True)
	loss, berTest = simNEQRNN.evalNeuralEQ(lossFn, batchSize=batchSize, seqLength=seqLength, rxDataTestNew=chOutTestFinal, txDataTestNew=chInTestFinal)
	print(f"testloss: {loss}, testber: {berTest},  epoch:{k}/{numEpoch}", flush=True)
